#include "database_preparedstatement_mysql.h"

#include <sstream>

#include "log.h"
#include "report.h"
#include "database_mysql_wrap.h"
#include "database_preparedstatement.h"

namespace afcore {
namespace database {

/// @brief  Log an attempt to bind a parameter position the statement does not have
/// @param  stmt_index      index of the prepared statement
/// @param  index           zero-based parameter position that was requested
/// @param  param_count     number of parameters the statement has
/// @return always false, so the call can be used as the failing arm of an assertion expression
static bool ParamenterIndexAssertFail(uint32_t stmt_index, uint8_t index, uint32_t param_count) {
  // The message reports the 1-based position, so the ordinal suffix has to be derived
  // from that same number rather than from the zero-based index.
  const uint32_t position = static_cast<uint32_t>(index) + 1;
  const uint32_t last_two_digits = position % 100;

  const char* suffix = "th";
  if (last_two_digits < 11 || last_two_digits > 13) {  // 11th, 12th and 13th are irregular
    switch (position % 10) {
      case 1:
        suffix = "st";
        break;
      case 2:
        suffix = "nd";
        break;
      case 3:
        suffix = "rd";
        break;
      default:
        break;
    }
  }

  // The last placeholder is the statement's parameter count, not the requested index.
  LOG_ERROR("Attempted to bind parameter {}{} on a PreparedStatement {} (statement has only {} parameters)",
            position, suffix, stmt_index, param_count);
  return false;
}

/// @brief  Store a fixed-size value in a MySQL bind parameter
/// @param  param           bind parameter to fill
/// @param  type            MySQL field type of the value
/// @param  value           pointer to the raw bytes of the value; the bytes are copied
/// @param  length          number of bytes to copy from value
/// @param  is_unsigned     whether the value is unsigned
static void SetParameterValue(SMysqlBind* param, enum_field_types type, const void* value, uint32_t length, bool is_unsigned) {
  param->buffer_type = type;

  // Replace any previously bound buffer with a fresh private copy of the value.
  delete[] static_cast<char*>(param->buffer);
  param->buffer = new char[length];
  param->buffer_length = 0;
  param->is_null_value = 0;

  // Only string/binary bindings carry a length object. Release one left over from an
  // earlier binding of this slot instead of dropping the pointer and leaking it.
  delete param->length;
  param->length = nullptr;
  param->is_unsigned = is_unsigned;

  memcpy(param->buffer, value, length);
}

/// @brief  Wrap a MySQL statement handle and allocate one zeroed bind slot per parameter
/// @param  stmt            MySQL statement handle; it is closed in the destructor
/// @param  query_string    SQL text of the statement, stored in query_string_
CPreparedStatementMysql::CPreparedStatementMysql(SMysqlStmt* stmt, std::string query_string)
  : stmt_mysql_(stmt)
  , query_string_(std::move(query_string)) {
  param_count_ = mysql_stmt_param_count(stmt);
  params_set_.assign(param_count_, false);
  bind_ = new SMysqlBind[param_count_];
  memset(bind_, 0, sizeof(SMysqlBind) * param_count_);

  // Setting this attribute to 1 makes mysql_stmt_store_result() update the
  // SMysqlField->max_length metadata. The value must be written with
  // mysql_stmt_attr_set(); mysql_stmt_attr_get() would only read the current setting.
  RMysqlBool bool_tmp = RMysqlBool(1);
  mysql_stmt_attr_set(stmt, STMT_ATTR_UPDATE_MAX_LENGTH, &bool_tmp);
}

/// @brief  Free the parameter buffers, close the MySQL statement and release the bind array
CPreparedStatementMysql::~CPreparedStatementMysql() {
  // Per-parameter buffers and length objects go first.
  ClearParameters();

  // When the handle reports bind_param_done, the length / is_null arrays hanging off
  // its bind pointer are released here before the handle itself is closed.
  if (stmt_mysql_->bind_param_done) {
    auto* handle_bind = stmt_mysql_->bind;
    delete[] handle_bind->length;
    delete[] handle_bind->is_null;
  }

  mysql_stmt_close(stmt_mysql_);
  delete[] bind_;
}

/// @brief  Bind SQL NULL to a statement parameter
/// @param  index   zero-based parameter position
void CPreparedStatementMysql::SetNull(uint8_t index) {
  AssertValidIndex(index);
  params_set_[index] = true;

  SMysqlBind* param = &bind_[index];
  param->buffer_type = MYSQL_TYPE_NULL;

  // A NULL binding carries no data: drop whatever buffer and length were bound before.
  delete[] static_cast<char*>(param->buffer);
  param->buffer = nullptr;
  param->buffer_length = 0;

  // Keep the null flag in agreement with the NULL type; leaving it at 0 would
  // describe this parameter as "not null".
  param->is_null_value = 1;

  delete param->length;
  param->length = nullptr;
}

/// @brief  Bind a boolean to a statement parameter, stored as an unsigned 8-bit 0 or 1
/// @param  index   zero-based parameter position
/// @param  value   boolean to bind
void CPreparedStatementMysql::SetBool(uint8_t index, const bool value) {
  const uint8_t as_byte = value ? 1 : 0;
  SetUInt8(index, as_byte);
}

/// @brief  Bind an unsigned 8-bit integer to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetUInt8(uint8_t index, uint8_t value) {
  AssertValidIndex(index);
  params_set_[index] = true;
  SetParameterValue(&bind_[index], MYSQL_TYPE_TINY, &value, sizeof(value), true);
}

/// @brief  Bind an unsigned 16-bit integer to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetUInt16(uint8_t index, uint16_t value) {
  AssertValidIndex(index);
  params_set_[index] = true;
  SetParameterValue(&bind_[index], MYSQL_TYPE_SHORT, &value, sizeof(value), true);
}

/// @brief  Bind an unsigned 32-bit integer to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetUInt32(uint8_t index, uint32_t value) {
  AssertValidIndex(index);
  params_set_[index] = true;
  SetParameterValue(&bind_[index], MYSQL_TYPE_LONG, &value, sizeof(value), true);
}

/// @brief  Bind an unsigned 64-bit integer to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetUInt64(uint8_t index, uint64_t value) {
  AssertValidIndex(index);
  params_set_[index] = true;
  SetParameterValue(&bind_[index], MYSQL_TYPE_LONGLONG, &value, sizeof(value), true);
}

/// @brief  Bind a signed 8-bit integer to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetInt8(uint8_t index, int8_t value) {
  AssertValidIndex(index);
  params_set_[index] = true;
  SetParameterValue(&bind_[index], MYSQL_TYPE_TINY, &value, sizeof(value), false);
}

/// @brief  Bind a signed 16-bit integer to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetInt16(uint8_t index, int16_t value) {
  AssertValidIndex(index);
  params_set_[index] = true;
  SetParameterValue(&bind_[index], MYSQL_TYPE_SHORT, &value, sizeof(value), false);
}

/// @brief  Bind a signed 32-bit integer to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetInt32(uint8_t index, int32_t value) {
  AssertValidIndex(index);
  params_set_[index] = true;
  SetParameterValue(&bind_[index], MYSQL_TYPE_LONG, &value, sizeof(value), false);
}

/// @brief  Bind a signed 64-bit integer to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetInt64(uint8_t index, int64_t value) {
  AssertValidIndex(index);
  params_set_[index] = true;
  SetParameterValue(&bind_[index], MYSQL_TYPE_LONGLONG, &value, sizeof(value), false);
}

/// @brief  Bind a single-precision float to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetFloat(uint8_t index, const float value) {
  AssertValidIndex(index);
  params_set_[index] = true;

  // NOTE(review): the unsigned flag is set whenever the value is strictly positive;
  // confirm this is what the bind layer expects for floating-point types.
  const bool flag_unsigned = value > 0.0f;
  SetParameterValue(&bind_[index], MYSQL_TYPE_FLOAT, &value, sizeof(value), flag_unsigned);
}

/// @brief  Bind a double-precision float to a statement parameter
/// @param  index   zero-based parameter position
/// @param  value   value to bind (copied into the bind buffer)
void CPreparedStatementMysql::SetDouble(uint8_t index, const double value) {
  AssertValidIndex(index);
  params_set_[index] = true;

  // NOTE(review): the unsigned flag is set whenever the value is strictly positive;
  // confirm this is what the bind layer expects for floating-point types.
  const bool flag_unsigned = value > 0.0;
  SetParameterValue(&bind_[index], MYSQL_TYPE_DOUBLE, &value, sizeof(value), flag_unsigned);
}

/// @brief  Bind a byte sequence to a statement parameter, as a blob or as a string
/// @param  index       zero-based parameter position
/// @param  value       bytes to bind; all of them are copied into the bind buffer
/// @param  is_string   true to bind as MYSQL_TYPE_VAR_STRING, false to bind as MYSQL_TYPE_BLOB
void CPreparedStatementMysql::SetBinary(uint8_t index, const std::vector<uint8_t>& value, bool is_string) {
  AssertValidIndex(index);
  params_set_[index] = true;

  SMysqlBind* param = &bind_[index];
  const uint32_t len = static_cast<uint32_t>(value.size());

  param->buffer_type = is_string ? MYSQL_TYPE_VAR_STRING : MYSQL_TYPE_BLOB;
  delete[] static_cast<char*>(param->buffer);
  param->buffer = new char[len];
  param->buffer_length = len;
  param->is_null_value = 0;

  // For strings the reported length excludes the final byte of the buffer
  // (presumably a trailing NUL supplied by the caller - TODO confirm). An empty
  // input must not be decremented: the unsigned length would wrap to a huge value.
  const unsigned long reported_length = (is_string && len > 0) ? len - 1 : len;
  delete param->length;
  param->length = new unsigned long(reported_length);

  // value.data() may be null for an empty vector, so skip the copy entirely.
  if (len > 0) {
    memcpy(param->buffer, value.data(), len);
  }
}

void CPreparedStatementMysql::ClearParameters() {
  for (uint32_t i = 0; i < param_count_; ++i) {
    delete bind_[i].length;
    bind_[i].length = nullptr;
    delete[] static_cast<char*>(bind_[i].buffer);
    bind_[i].buffer = nullptr;
    params_set_[i] = false;
  }
}

/// @brief  Assert that a parameter position is in range and log if it was already bound
/// @param  index   zero-based parameter position to check
void CPreparedStatementMysql::AssertValidIndex(uint8_t index) {
  // An out-of-range index is logged by ParamenterIndexAssertFail, which returns false
  // and thereby makes the assertion fail.
  DBG_ASSERT(index < param_count_ || ParamenterIndexAssertFail(stmt_->index_, index, param_count_));

  // Binding a slot for the first time needs no report.
  if (!params_set_[index]) {
    return;
  }

  LOG_ERROR("Prepared Statement (id: {}) trying to bind value on already bound index ({}).", stmt_->index_, index);
}

/// @brief  Build a readable copy of the query with each '?' replaced by its bound value
/// @return the query text with placeholders substituted in order; string values are
///         wrapped in single quotes (not escaped), binary values appear as BINARY and
///         null values as NULL
std::string CPreparedStatementMysql::GetQueryString() const {
  std::string query_string(query_string_);

  size_t pos = 0;
  for (uint32_t i = 0; i < stmt_->statement_data_.size(); ++i) {
    pos = query_string.find('?', pos);
    if (pos == std::string::npos) {
      // More bound values than placeholders: replace() at npos would throw
      // std::out_of_range, so stop substituting and return what we have.
      break;
    }

    const auto& item = stmt_->statement_data_[i];
    std::stringstream ss;

    switch (item.type) {
      // 8-bit values are widened so they are streamed as numbers, not characters.
      case kPreparedStatementValueType_Bool:
        ss << static_cast<uint16_t>(item.data.boolean);
        break;
      case kPreparedStatementValueType_Ui8:
        ss << static_cast<uint16_t>(item.data.ui8);
        break;
      case kPreparedStatementValueType_Ui16:
        ss << item.data.ui16;
        break;
      case kPreparedStatementValueType_Ui32:
        ss << item.data.ui32;
        break;
      case kPreparedStatementValueType_Ui64:
        ss << item.data.ui64;
        break;
      case kPreparedStatementValueType_I8:
        ss << static_cast<int16_t>(item.data.i8);
        break;
      case kPreparedStatementValueType_I16:
        ss << item.data.i16;
        break;
      case kPreparedStatementValueType_I32:
        ss << item.data.i32;
        break;
      case kPreparedStatementValueType_I64:
        ss << item.data.i64;
        break;
      case kPreparedStatementValueType_Float:
        ss << item.data.f;
        break;
      case kPreparedStatementValueType_Double:
        ss << item.data.d;
        break;
      case kPreparedStatementValueType_String:
        // Streamed as a C string, so the stored bytes are assumed to be NUL-terminated
        // - TODO confirm against the code that fills statement_data_.
        ss << '\'' << reinterpret_cast<const char*>(item.binary.data()) << '\'';
        break;
      case kPreparedStatementValueType_Binary:
        ss << "BINARY";
        break;
      case kPreparedStatementValueType_Null:
        ss << "NULL";
        break;
    }

    const std::string replace_str = ss.str();
    query_string.replace(pos, 1, replace_str);
    // Continue searching after the inserted text so a '?' inside a value is not re-matched.
    pos += replace_str.length();
  }
  return query_string;
}

} // !namespace database
} // !namespace afcore